﻿using Extented.UI.Core.Native;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Windows;
using System.Windows.Media;

namespace Extented.UI.Core.Helpers
{
    /// <summary>
    /// Helpers for working with a tree of <see cref="FrameworkRenderElementContext"/> nodes:
    /// hit testing, computing transforms between related nodes, and enumerating
    /// ancestors and descendants.
    /// </summary>
    public static class RenderTreeHelper
    {
        /// <summary>
        /// Hit tests the subtree rooted at <paramref name="context"/> and returns the last
        /// element reported by the traversal (every hit is accepted and the traversal is
        /// never stopped early, so later hits overwrite earlier ones).
        /// </summary>
        /// <param name="context">Element at which the traversal starts.</param>
        /// <param name="relativePoint">Point to test.</param>
        /// <returns>The last hit element in traversal order, or <c>null</c> when nothing was hit.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
        public static FrameworkRenderElementContext HitTest(FrameworkRenderElementContext context, Point relativePoint)
        {
            FrameworkRenderElementContext lastHit = null;
            HitTest(context,
                frec => RenderHitTestFilterBehavior.Continue,
                frec => { lastHit = frec; return RenderHitTestResultBehavior.Continue; },
                relativePoint);
            return lastHit;
        }

        /// <summary>
        /// Hit tests the subtree rooted at <paramref name="context"/> and returns the first
        /// hit element (in traversal order) that satisfies <paramref name="predicate"/>.
        /// </summary>
        /// <param name="context">Element at which the traversal starts.</param>
        /// <param name="relativePoint">Point to test.</param>
        /// <param name="predicate">Condition a hit element must satisfy to be returned.</param>
        /// <returns>The first matching hit element, or <c>null</c> when there is none.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="context"/> or <paramref name="predicate"/> is <c>null</c>.
        /// </exception>
        public static FrameworkRenderElementContext HitTest(FrameworkRenderElementContext context, Point relativePoint, Func<FrameworkRenderElementContext, bool> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }

            FrameworkRenderElementContext match = null;
            HitTest(context,
                frec => RenderHitTestFilterBehavior.Continue,
                frec =>
                {
                    if (!predicate(frec))
                    {
                        return RenderHitTestResultBehavior.Continue;
                    }
                    match = frec;
                    return RenderHitTestResultBehavior.Stop;
                },
                relativePoint);
            return match;
        }

        /// <summary>
        /// Walks the subtree rooted at <paramref name="context"/> depth-first (an element is
        /// visited before its children, children in index order), hit testing every visited
        /// element. Hits are collected during the walk and reported to
        /// <paramref name="resultCallback"/> afterwards, in visiting order.
        /// </summary>
        /// <param name="context">Element at which the traversal starts.</param>
        /// <param name="filterCallback">
        /// Invoked for every visited element; its result decides whether the element itself
        /// may be recorded as a hit, whether its children are visited, and whether the whole
        /// traversal is aborted.
        /// </param>
        /// <param name="resultCallback">
        /// Invoked for every recorded hit; returning <c>RenderHitTestResultBehavior.Stop</c>
        /// ends the reporting.
        /// </param>
        /// <param name="relativePoint">
        /// Point to test. It is first mapped through the inverse of the render transform of
        /// the top-most ancestor of <paramref name="context"/> (or of <paramref name="context"/>
        /// itself when it has no parent).
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="context"/>, <paramref name="filterCallback"/> or
        /// <paramref name="resultCallback"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// A transform matrix involved in the mapping is not invertible.
        /// </exception>
        public static void HitTest(FrameworkRenderElementContext context,
                                    Func<FrameworkRenderElementContext, RenderHitTestFilterBehavior> filterCallback,
                                    Func<FrameworkRenderElementContext, RenderHitTestResultBehavior> resultCallback,
                                    Point relativePoint)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            if (filterCallback == null)
            {
                throw new ArgumentNullException("filterCallback");
            }
            if (resultCallback == null)
            {
                throw new ArgumentNullException("resultCallback");
            }

            var rootTransform = RenderTransformValue(GetRoot(context));
            rootTransform.Invert();
            relativePoint = rootTransform.Transform(relativePoint);

            var hits = new List<FrameworkRenderElementContext>();
            // Explicit stack for the depth-first walk; push/pop are O(1), unlike
            // removing from / inserting at the front of a list.
            var pending = new Stack<FrameworkRenderElementContext>();
            pending.Push(context);
            while (pending.Count > 0)
            {
                var current = pending.Pop();
                // The element is hit tested before the filter is consulted.
                var isHit = current.HitTest(GetMatrixTransformToDescendant(context, current).Transform(relativePoint));
                var filter = filterCallback(current);
                if (!filter.HasFlag(RenderHitTestFilterBehavior.Continue))
                {
                    // NOTE(review): aborting here discards the hits collected so far without
                    // ever invoking resultCallback - confirm this is the intended contract.
                    return;
                }
                if (!filter.HasFlag(RenderHitTestFilterBehavior.ContinueSkipSelf) && isHit)
                {
                    hits.Add(current);
                }
                if (!filter.HasFlag(RenderHitTestFilterBehavior.ContinueSkipChildren))
                {
                    // Push in reverse so that the first child ends up on top of the stack
                    // and children are therefore visited in index order.
                    foreach (var child in RenderChildren(current).Reverse())
                    {
                        pending.Push(child);
                    }
                }
            }

            foreach (var hit in hits)
            {
                if (resultCallback(hit) == RenderHitTestResultBehavior.Stop)
                {
                    return;
                }
            }
        }

        /// <summary>
        /// Returns the accumulated render transform from <paramref name="element"/> up to,
        /// but not including, <paramref name="ancestor"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <paramref name="ancestor"/> is not an ancestor of <paramref name="element"/>.
        /// </exception>
        public static Transform TransformToAncestor(FrameworkRenderElementContext element, FrameworkRenderElementContext ancestor)
        {
            return new MatrixTransform(GetMatrixTransformToAncestor(element, ancestor));
        }

        /// <summary>
        /// Returns the inverse of the accumulated render transform from
        /// <paramref name="descendant"/> up to <paramref name="element"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <paramref name="element"/> is not an ancestor of <paramref name="descendant"/>.
        /// </exception>
        public static Transform TransformToDescendant(FrameworkRenderElementContext element, FrameworkRenderElementContext descendant)
        {
            return new MatrixTransform(GetMatrixTransformToDescendant(element, descendant));
        }

        /// <summary>
        /// Returns the accumulated render transform from <paramref name="element"/> to its
        /// top-most ancestor, including that ancestor's own render transform.
        /// </summary>
        public static Transform TransformToRoot(FrameworkRenderElementContext element)
        {
            var root = GetRoot(element);
            return new MatrixTransform(GetMatrixTransformToAncestor(element, root) * RenderTransformValue(root));
        }

        /// <summary>
        /// Returns the top-most ancestor of <paramref name="element"/>, or the element itself
        /// when it has no parent.
        /// </summary>
        private static FrameworkRenderElementContext GetRoot(FrameworkRenderElementContext element)
        {
            return RenderAncestors(element).LastOrDefault() ?? element;
        }

        /// <summary>
        /// Multiplies the render transforms of <paramref name="element"/> and of every
        /// ancestor below <paramref name="ancestor"/>; the ancestor's own transform is not
        /// included. Returns identity when both arguments are the same element.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <paramref name="ancestor"/> was not found among the ancestors of <paramref name="element"/>.
        /// </exception>
        private static Matrix GetMatrixTransformToAncestor(FrameworkRenderElementContext element, FrameworkRenderElementContext ancestor)
        {
            if (element == ancestor)
            {
                return Matrix.Identity;
            }

            Matrix result = RenderTransformValue(element);
            foreach (var current in RenderAncestors(element))
            {
                if (current == ancestor)
                {
                    return result;
                }
                result *= RenderTransformValue(current);
            }

            // The whole parent chain was walked without meeting the requested ancestor.
            throw new ArgumentException("The specified context is not an ancestor of the element.", "ancestor");
        }

        /// <summary>
        /// Returns the matrix of the element's render transform, or identity when the
        /// element has no render transform.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="element"/> is <c>null</c>.</exception>
        private static Matrix RenderTransformValue(FrameworkRenderElementContext element)
        {
            if (element == null)
            {
                throw new ArgumentNullException("element");
            }
            var transform = element.RenderTransform;
            return transform == null ? Matrix.Identity : transform.Value;
        }

        /// <summary>
        /// Returns the inverse of the matrix produced by
        /// <see cref="GetMatrixTransformToAncestor"/> for the pair
        /// (<paramref name="descendant"/>, <paramref name="element"/>).
        /// </summary>
        private static Matrix GetMatrixTransformToDescendant(FrameworkRenderElementContext element, FrameworkRenderElementContext descendant)
        {
            var matrix = GetMatrixTransformToAncestor(descendant, element);
            matrix.Invert();
            return matrix;
        }

        /// <summary>
        /// Lazily enumerates the parent chain of <paramref name="context"/>, nearest parent
        /// first. The element itself is not included; a <c>null</c> context yields nothing.
        /// </summary>
        public static IEnumerable<FrameworkRenderElementContext> RenderAncestors(FrameworkRenderElementContext context)
        {
            if (context == null)
            {
                yield break;
            }
            for (var parent = context.Parent; parent != null; parent = parent.Parent)
            {
                yield return parent;
            }
        }

        /// <summary>
        /// Lazily enumerates every descendant of <paramref name="context"/> depth-first:
        /// each child is returned before its own descendants, children in index order.
        /// </summary>
        public static IEnumerable<FrameworkRenderElementContext> RenderDescendants(FrameworkRenderElementContext context)
        {
            foreach (var child in RenderChildren(context))
            {
                yield return child;
                foreach (var nested in RenderDescendants(child))
                {
                    yield return nested;
                }
            }
        }

        /// <summary>
        /// Lazily enumerates the direct render children of <paramref name="context"/> in
        /// index order, through its <see cref="IFrameworkRenderElementContext"/> interface.
        /// </summary>
        private static IEnumerable<FrameworkRenderElementContext> RenderChildren(FrameworkRenderElementContext context)
        {
            var iContext = (IFrameworkRenderElementContext)context;
            for (int index = 0; index < iContext.RenderChildrenCount; index++)
            {
                yield return iContext.GetRenderChild(index);
            }
        }

        /// <summary>
        /// Returns the first descendant of <paramref name="context"/> (depth-first order)
        /// that is of type <typeparamref name="T"/>, or <c>null</c> when there is none.
        /// </summary>
        public static T FindDescendant<T>(FrameworkRenderElementContext context) where T : FrameworkRenderElementContext
        {
            return (T)FindDescendant(context, x => x is T);
        }

        /// <summary>
        /// Returns the first descendant of <paramref name="context"/> (depth-first order)
        /// that satisfies <paramref name="predicate"/>, or <c>null</c> when there is none.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="context"/> or <paramref name="predicate"/> is <c>null</c>.
        /// </exception>
        public static FrameworkRenderElementContext FindDescendant(FrameworkRenderElementContext context, Func<FrameworkRenderElementContext, bool> predicate)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            return RenderDescendants(context).FirstOrDefault(predicate);
        }
    }
}
